{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "804a617d-b335-49d2-8883-e4c3992695ee",
   "metadata": {},
   "source": [
    "Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
    "\n",
    "https://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50ff8ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
    "%env CUDA_VISIBLE_DEVICES=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f34a8c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import pickle\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "from ipywidgets import HTML\n",
    "from ipyevents import Event\n",
    "from baukit import renormalize\n",
    "\n",
    "from models.layout import model_full\n",
    "from utils import sky_util, soat_util, camera_util, render_settings, show, filters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91216085",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device that the model and the tensors created below are placed on.\n",
    "device = 'cuda'\n",
    "\n",
    "# This notebook only runs inference, so autograd is switched off globally.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4cc1e6e",
   "metadata": {},
   "source": [
    "# load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "530a4d16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint paths, relative to the notebook's working directory.\n",
    "terrain_checkpoint = 'pretrained/model_terrain.pkl'\n",
    "sky_checkpoint = 'pretrained/model_sky_360.pkl'\n",
    "\n",
    "# Build the model from both checkpoints, move it to `device`, and switch it to eval mode.\n",
    "full_model = model_full.ModelFull(terrain_checkpoint, sky_checkpoint)\n",
    "full_model = full_model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3aeafb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the interactive rendering preset from render_settings.\n",
    "full_model.set_nerf_params(**render_settings.nerf_render_interactive)\n",
    "\n",
    "# Terrain side: the layout model, plus the SOAT model initialised from it.\n",
    "terrain_model = full_model.terrain_model\n",
    "G_layout = terrain_model.layout_model\n",
    "G_soat = soat_util.init_soat_model(G_layout)\n",
    "\n",
    "# Sky side: the sky model, a grid built from its generator, and the generator's\n",
    "# synthesis input layer (the last two are passed to generate_start_grid later).\n",
    "G_sky = full_model.sky_model\n",
    "sky_generator = G_sky.G\n",
    "grid = sky_util.make_grid(sky_generator)\n",
    "input_layer = sky_generator.synthesis.input"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59e45118",
   "metadata": {},
   "source": [
    "# generate initial layout and skydome env map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65820c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 944 # np.random.randint(0, 1000)\n",
    "truncation = 0.8\n",
    "grid_size = 5\n",
    "\n",
    "# Seed the global torch and numpy RNGs from `seed`. Without this, `z` and\n",
    "# `noise_input` below (both drawn with torch.randn*) differ on every\n",
    "# Restart Kernel -> Run All even though `seed` is fixed. Seeding numpy as well\n",
    "# covers any helper below that draws from the global numpy RNG.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Layout generated on a grid_size x grid_size grid for the chosen seed.\n",
    "layout = soat_util.generate_layout(seed, grid_h=grid_size, grid_w=grid_size, device=device, truncation_psi=truncation)\n",
    "\n",
    "# Latent code handed to full_model; a slice of it is reused later as the sky latent.\n",
    "z = torch.randn(1, G_layout.layout_generator.z_dim, device=device)\n",
    "# Second positional argument to full_model; left unset in this notebook.\n",
    "c = None\n",
    "# Noise with the layout's shape, keeping only the first entry along dim 1.\n",
    "noise_input = torch.randn_like(layout)[:, :1]\n",
    "\n",
    "# Sample a camera pose for this layout and convert it to a camera object.\n",
    "sampled_Rt = G_layout.trajectory_sampler.sample_trajectories(G_layout.layout_decoder, layout)\n",
    "initial_camera = camera_util.camera_from_pose(sampled_Rt.squeeze())\n",
    "\n",
    "print(initial_camera)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddf0303b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the first frame from the initial camera.\n",
    "Rt = camera_util.pose_from_camera(initial_camera)[None].to(device)\n",
    "nerf_out_res = G_layout.layout_decoder_kwargs.nerf_out_res\n",
    "camera_params = camera_util.get_full_image_parameters(\n",
    "    G_layout, nerf_out_res, batch_size=1, device=device, Rt=Rt)\n",
    "initial_nerf_kwargs = dict(cached_layout=layout, extras=[], noise_input=noise_input)\n",
    "outputs = full_model(z, c, camera_params, truncation=truncation,\n",
    "                     nerf_kwargs=initial_nerf_kwargs)\n",
    "\n",
    "# Build the sky texture from that frame: encode the frame multiplied by its sky\n",
    "# mask, generate a panorama from the encoding, and add a leading axis.\n",
    "sky_only_frame = outputs['rgb_up'] * outputs['sky_mask']\n",
    "sky_encoder_ws = G_sky.encode(sky_only_frame)\n",
    "sky_z = z[:, :G_sky.G.z_dim]\n",
    "start_grid = sky_util.generate_start_grid(seed, input_layer, grid)\n",
    "sky_pano = sky_util.generate_pano_transform(G_sky.G, sky_z, sky_encoder_ws, start_grid)\n",
    "sky_texture = sky_pano[None]\n",
    "\n",
    "# To inspect the encoder input instead: show(renormalize.as_image(sky_only_frame[0]))\n",
    "show(renormalize.as_image(sky_texture[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ecbb7a2",
   "metadata": {},
   "source": [
    "# interactive widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c76602f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Auto-adjust controller settings ---------------------------------------\n",
    "# Gains for the tilt and offset corrections. If they are too large the\n",
    "# correction overshoots; if they are too small the camera reacts too late to\n",
    "# avoid mountains. Keep the tilt gain small, otherwise you'll get motion sickness.\n",
    "tilt_velocity_scale = 0.3\n",
    "offset_velocity_scale = 0.5\n",
    "\n",
    "# Ideal height of the horizon within the image. Suggested range: 0.5 to 0.7.\n",
    "horizon_target = 0.65\n",
    "\n",
    "# Ideal proportion of the depth map that is \"near\" the camera. Smaller values\n",
    "# make the camera fly higher. Suggested range: 0.05 to 0.2.\n",
    "near_target = 0.2\n",
    "\n",
    "# Controller state, updated after every rendered frame.\n",
    "offset = 0\n",
    "tilt = 0\n",
    "# Number of frames rendered before the widget is shown.\n",
    "initial_stabilize_frames = 10\n",
    "\n",
    "# --- Camera state -----------------------------------------------------------\n",
    "camera = initial_camera\n",
    "camera_util.INITIAL_CAMERA = initial_camera\n",
    "\n",
    "# --- Widgets ----------------------------------------------------------------\n",
    "# `l` receives the rendered images and is the source of keydown events (`d`);\n",
    "# `h` receives the camera description as text.\n",
    "display_size = (256, 256)\n",
    "l = HTML(\"\")\n",
    "h = HTML(\"\")\n",
    "d = Event(source=l, watched_events=['keydown'])\n",
    "\n",
    "def generate_frame_from_camera(camera):\n",
    "    \"\"\"Render one frame of the cached layout as seen from `camera`.\n",
    "\n",
    "    Reads the notebook-level `z`, `c`, `layout`, `noise_input`, `truncation`\n",
    "    and `sky_texture`. Requests 'camera_points' as an extra and returns the\n",
    "    output of `full_model` unchanged.\n",
    "    \"\"\"\n",
    "    pose = camera_util.pose_from_camera(camera)[None].to(device)\n",
    "    resolution = G_layout.layout_decoder_kwargs.nerf_out_res\n",
    "    params = camera_util.get_full_image_parameters(\n",
    "        G_layout, resolution, batch_size=1, device=device, Rt=pose)\n",
    "    render_kwargs = {\n",
    "        'extras': ['camera_points'],\n",
    "        'cached_layout': layout,\n",
    "        'noise_input': noise_input,\n",
    "    }\n",
    "    return full_model(z, c, params, truncation=truncation,\n",
    "                      nerf_kwargs=render_kwargs, sky_texture=sky_texture)\n",
    "    \n",
    "def update_display(outputs, camera):\n",
    "    \"\"\"Show a rendered frame and the camera description in the widgets.\n",
    "\n",
    "    Writes an HTML row with two images (the composite render and the ray\n",
    "    visualization) into `l`, and `str(camera)` into `h`.\n",
    "    \"\"\"\n",
    "    render_url = renormalize.as_url(outputs['rgb_overlay_upsample'][0], size=display_size)\n",
    "\n",
    "    extras = outputs['extras']\n",
    "    rays = camera_util.visualize_rays(\n",
    "        G_layout, extras['Rt'], extras['camera_points'], extras['layout'], display_size[0])\n",
    "    rays_url = renormalize.as_url(renormalize.as_image(rays), size=display_size)\n",
    "\n",
    "    l.value = '<div class=\"row\"> <img src=\"%s\"/> <img src=\"%s\"/> </div>' % (render_url, rays_url)\n",
    "    h.value = str(camera)\n",
    "    \n",
    "\n",
    "def handle_event(event):\n",
    "    \"\"\"Handle one key press: move the camera, render, display, update the controller.\n",
    "\n",
    "    `event` only needs a 'key' entry. Rebinds the notebook-level `camera`,\n",
    "    `offset` and `tilt`.\n",
    "    \"\"\"\n",
    "    global camera, offset, tilt\n",
    "    camera = camera_util.update_camera(camera, event['key'], auto_adjust_height_and_tilt=True)\n",
    "\n",
    "    # Render from a vertically adjusted copy of the camera; `camera` itself is not replaced by it.\n",
    "    view_camera = camera_util.adjust_camera_vertically(camera, offset, tilt)\n",
    "    frame = generate_frame_from_camera(view_camera)\n",
    "    frame = filters.smooth_mask(frame)  # optional mask smoothing\n",
    "    update_display(frame, view_camera)\n",
    "\n",
    "    # Feed the rendered frame back into the tilt/offset update for the next key press.\n",
    "    controller_settings = dict(horizon_target=horizon_target,\n",
    "                               near_target=near_target,\n",
    "                               tilt_velocity_scale=tilt_velocity_scale,\n",
    "                               offset_velocity_scale=offset_velocity_scale)\n",
    "    tilt, offset = camera_util.update_tilt_and_offset(frame, tilt, offset, **controller_settings)\n",
    "\n",
    "    \n",
    "# Warm-up: render a few frames without displaying them so that `tilt` and\n",
    "# `offset` have already been updated before the first frame is shown.\n",
    "for _warmup_step in range(initial_stabilize_frames):\n",
    "    warmup_camera = camera_util.adjust_camera_vertically(camera, offset, tilt)\n",
    "    outputs = generate_frame_from_camera(warmup_camera)\n",
    "    tilt, offset = camera_util.update_tilt_and_offset(\n",
    "        outputs, tilt, offset,\n",
    "        horizon_target=horizon_target,\n",
    "        near_target=near_target,\n",
    "        tilt_velocity_scale=tilt_velocity_scale,\n",
    "        offset_velocity_scale=offset_velocity_scale)\n",
    "\n",
    "# Register the key handler, show the widgets, then call the handler once by\n",
    "# hand (with key 'x') so a frame is visible before any key is pressed.\n",
    "d.on_dom_event(handle_event)\n",
    "display(h, l)\n",
    "handle_event({'key': 'x'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c958a58",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "persistentnature",
   "language": "python",
   "name": "persistentnature"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
